Classification model interpretation using the explainer package
Ramtin Zargari Marandi
explainer_tutorial.Rmd

This is an R Markdown Notebook. When you execute code within the notebook, the results appear beneath the code. Try executing a chunk by clicking the Run button within the chunk, or by placing your cursor inside it and pressing Ctrl+Shift+Enter.
This is an example of how to use the "explainer" package, developed by Ramtin Zargari Marandi (email: ramtin.zargari.marandi@regionh.dk).
Loading a dataset and training a machine learning model
This first code chunk loads a dataset, creates a binary classification task, and trains a random forest model using the mlr3 package.
Sys.setenv(LANG = "en") # change R language to English!
RNGkind("L'Ecuyer-CMRG") # change to L'Ecuyer-CMRG in case it uses default "Mersenne-Twister"
library("explainer")
# set seed for reproducibility
seed <- 246
set.seed(seed)
# set to TRUE if you have your own dataset; otherwise set it to FALSE
data_availability <- FALSE
if (data_availability == FALSE) {
  # if you don't have a dataset, you can try the following publicly available one:
  # load the BreastCancer data from the mlbench package
  data("BreastCancer", package = "mlbench")
  # keep the target column as "Class"
  target_col <- "Class"
  # set the positive class to "malignant"
  positive_class <- "malignant"
  # keep only the predictor variables and the outcome
  mydata <- BreastCancer[, -1] # column 1 is the ID
  # remove rows with missing values
  mydata <- na.omit(mydata)
  # create a random vector of sex categories
  sex <- sample(c("Male", "Female"), size = nrow(mydata), replace = TRUE)
  # create a random age column
  mydata$age <- as.numeric(sample(seq(18, 60), size = nrow(mydata), replace = TRUE))
  # add the sex column to the mydata data frame (for fairness analysis)
  mydata$sex <- factor(sex, levels = c("Male", "Female"), labels = c(1, 0))
}
# create a classification task
maintask <- mlr3::TaskClassif$new(
  id = "my_classification_task",
  backend = mydata,
  target = target_col,
  positive = positive_class
)
# create a train-test split
set.seed(seed)
splits <- mlr3::partition(maintask)
# add a learner (machine learning model base)
# library("mlr3learners")
library("mlr3extralearners")
# mlr_learners$get("classif.randomForest")
# here we use random forest for example (you can use any other available model)
mylrn <- mlr3::lrn("classif.randomForest", predict_type = "prob") # , id = "mymodel"
# train the model
mylrn$train(maintask, splits$train)
# make predictions on new data
mylrn$predict(maintask, splits$test)

## <PredictionClassif> for 226 observations:
## row_ids truth response prob.malignant prob.benign
## 2 benign malignant 0.790 0.210
## 3 benign benign 0.000 1.000
## 4 benign malignant 0.920 0.080
## ---
## 655 malignant malignant 0.884 0.116
## 665 malignant malignant 1.000 0.000
## 683 malignant malignant 0.902 0.098
SHAP analysis to extract feature (variable) impacts on predictions
The following code chunk uses the eSHAP_plot function to estimate SHAP values for the test set and create an interactive SHAP plot. This is an enhanced SHAP plot in the sense that it provides additional information, such as whether each prediction was correct (e.g., TP or TN). The color mapping supports visual inspection of the plot.
# enhanced SHAP plot
SHAP_output <- eSHAP_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  sample.size = 30,
  seed = seed,
  subset = .8
)
# display the SHAP plot
SHAP_output[[1]]

Visualize model performance by confusion matrix
The following code chunk uses the eCM_plot function to visualize the confusion matrix of the model. More information can be found here: https://en.wikipedia.org/wiki/Confusion_matrix https://cran.r-project.org/web/packages/cvms/vignettes/Creating_a_confusion_matrix.html
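As a minimal, self-contained illustration of what a confusion matrix contains, truth and predicted classes can be cross-tabulated with base R (the vectors below are hypothetical, not the model output above):

```r
# toy sketch: cross-tabulate truth vs. predicted class with base R
truth <- factor(c("malignant", "benign", "benign", "malignant", "benign"),
                levels = c("malignant", "benign"))
response <- factor(c("malignant", "malignant", "benign", "malignant", "benign"),
                   levels = c("malignant", "benign"))
cm <- table(Truth = truth, Predicted = response)
cm  # TP = cm["malignant", "malignant"], FP = cm["benign", "malignant"], etc.
```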
# enhanced confusion matrix
eCM_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits
)


Decision curve analysis
The following code chunk uses the eDecisionCurve function to conduct decision curve analysis on the test set. For an in-depth understanding of the methodology, see the following references:
Decision Curve Analysis: https://en.wikipedia.org/wiki/Decision_curve_analysis "Decision curve analysis: a novel method for evaluating prediction models" by Andrew J. Vickers and Elena B. Elkin. Link: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2577036/
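The core quantity behind a decision curve is the net benefit of treating everyone whose predicted probability exceeds a threshold pt: NB(pt) = TP/n − (FP/n) · pt/(1 − pt). The following base-R sketch illustrates the definition only; it is not the eDecisionCurve implementation, and the data are hypothetical:

```r
# net benefit of "treat if predicted probability >= pt" on a labeled sample
net_benefit <- function(prob, truth, pt) {
  pred <- prob >= pt
  tp <- sum(pred & truth == 1)  # treated true positives
  fp <- sum(pred & truth == 0)  # treated false positives
  n <- length(truth)
  tp / n - fp / n * pt / (1 - pt)
}
# a perfect classifier at pt = 0.5 attains the prevalence as its net benefit
net_benefit(prob = c(1, 1, 0, 0), truth = c(1, 1, 0, 0), pt = 0.5)  # 0.5
```

Plotting this quantity over a range of thresholds, against the "treat all" and "treat none" strategies, yields the decision curve.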
# enhanced decision curve plot
eDecisionCurve(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  seed = seed
)

Model evaluation (multi-metrics and visual inspection of ROC curves)
By running the next code chunk, you will get the following model evaluation metrics and visualizations:
AUC (Area Under the Curve): AUC quantifies the binary classification model's performance by assessing the area under the ROC curve, which plots sensitivity against 1 − specificity across threshold values. A value of 0.5 suggests chance-level performance, while 1 signifies perfect classification.
BACC (Balanced Accuracy): BACC addresses class imbalance by averaging sensitivity and specificity. Ranging from 0 to 1, a score of 0 indicates chance performance, and 1 signifies perfect classification.
MCC (Matthews Correlation Coefficient): MCC evaluates binary classification model quality, considering true positives, true negatives, false positives, and false negatives. Ranging from -1 to 1, -1 represents complete disagreement, 0 implies chance performance, and 1 indicates perfect classification.
BBRIER (Binary Brier Score): BBRIER gauges the accuracy of probabilistic predictions by measuring the mean squared difference between predicted probabilities and true binary outcomes. Values range from 0 to 1, with 0 indicating perfect calibration and higher values indicating poorer calibration.
PPV (Positive Predictive Value): PPV, or precision, measures the proportion of true positive predictions out of all positive predictions made by the model.
NPV (Negative Predictive Value): NPV quantifies the proportion of true negative predictions out of all negative predictions made by the model.
Specificity: Specificity calculates the proportion of true negative predictions out of all actual negative cases in a binary classification problem.
Sensitivity: Sensitivity, also known as recall or true positive rate, measures the proportion of true positive predictions out of all actual positive cases in a binary classification problem.
PRAUC (Precision-Recall Area Under the Curve): PRAUC assesses binary classification model performance based on precision and recall, quantifying the area under the precision-recall curve. A PRAUC value of 1 indicates perfect classification performance.
Additionally, the analysis involves the visualization of ROC and Precision-Recall curves for both development and test sets.
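To make the definitions above concrete, the threshold-based metrics can be computed by hand from confusion-matrix counts. This helper is illustrative only; eROC_plot itself relies on mlr3 measures, and the counts below are hypothetical:

```r
# hand-computed versions of some of the metrics described above
classif_metrics <- function(tp, fp, tn, fn) {
  sens <- tp / (tp + fn)              # sensitivity (recall)
  spec <- tn / (tn + fp)              # specificity
  ppv  <- tp / (tp + fp)              # positive predictive value
  npv  <- tn / (tn + fn)              # negative predictive value
  bacc <- (sens + spec) / 2           # balanced accuracy
  mcc  <- (tp * tn - fp * fn) /
    sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
  c(sensitivity = sens, specificity = spec, ppv = ppv,
    npv = npv, bacc = bacc, mcc = mcc)
}
round(classif_metrics(tp = 70, fp = 2, tn = 140, fn = 1), 3)
```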
eROC_plot(
  task = maintask,
  trained_model = mylrn,
  splits = splits
)

## [[1]]

##
## [[2]]
## pred_results$score(measures = mlr3::msrs(meas))
## auc 1
## bacc 1
## mcc 1
## bbrier 0
## ppv 1
## npv 1
## specificity 1
## sensitivity 1
## prauc 1
##
## [[3]]
## pred_results_test$score(measures = mlr3::msrs(meas))
## auc 0.99
## bacc 0.98
## mcc 0.94
## bbrier 0.03
## ppv 0.94
## npv 0.99
## specificity 0.97
## sensitivity 0.99
## prauc 0.98
ROC curves with annotated thresholds
By running the next code chunk, you will get ROC and Precision-Recall curves for the development and test sets, this time annotated with probability thresholds.
eperformance(
  task = maintask,
  trained_model = mylrn,
  splits = splits
)

## [[1]]

##
## [[2]]
##
## [[3]]
Loading SHAP results for downstream analysis
Now we can take the outputs of the eSHAP_plot function and apply clustering to the SHAP values.
shap_Mean_wide <- SHAP_output[[2]]
shap_Mean_long <- SHAP_output[[3]]
shap <- SHAP_output[[4]]

SHAP values in association with feature values
SHAP_vs_FVAL_plt(shap_Mean_long)
Partial dependence of features
Partial dependence plots (PDPs): PDPs can be used to visualize the marginal effect of a single feature on the model prediction.
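The computation behind a one-feature PDP can be sketched in a few lines of base R: fix the feature of interest at each grid value, predict for every row, and average the predictions. This is a generic illustration using a toy glm on mtcars, not the tutorial's random forest or the package's own plotting code:

```r
# partial dependence of a single feature on the model prediction
pdp_1d <- function(model, data, feature, grid) {
  sapply(grid, function(v) {
    data[[feature]] <- v  # force the feature to a fixed value for all rows
    mean(predict(model, newdata = data, type = "response"))
  })
}
fit <- glm(am ~ mpg + wt, data = mtcars, family = binomial)
grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 5)
pd <- pdp_1d(fit, mtcars, "wt", grid)  # marginal effect of wt on P(am = 1)
```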
PP_vs_FVAL_plt(shap_Mean_long = shap_Mean_long)
Extract feature values and predicted probabilities from the model output for further analysis
fval_predprob <- reshape2::dcast(shap, sample_num + pred_prob + predcorrectness ~ feature, value.var = "feature.value")

Run a Shiny app to visualize 2-way partial dependence plots
library("shiny")
# assuming your data.table is named `fval_predprob`
ui <- fluidPage(
  titlePanel("Feature Plot"),
  sidebarLayout(
    sidebarPanel(
      selectInput("x_feature", "X-axis Feature:",
                  choices = names(fval_predprob)),
      selectInput("y_feature", "Y-axis Feature:",
                  choices = names(fval_predprob)),
      width = 2,
      actionButton("stop_button", "Stop")
    ),
    mainPanel(
      plotOutput("feature_plot")
    )
  )
)
server <- function(input, output) {
  # create a reactive value for the app state
  app_state <- reactiveValues(running = TRUE)
  output$feature_plot <- renderPlot({
    ggplot(fval_predprob,
           aes(x = .data[[input$x_feature]], y = .data[[input$y_feature]],
               fill = pred_prob)) +
      geom_tile() +
      scale_fill_gradient(low = "white", high = "steelblue", limits = c(0, 1)) +
      xlab(input$x_feature) +
      ylab(input$y_feature) +
      labs(shape = "correct prediction") +
      egg::theme_article()
  })
  # observe the stop button
  observeEvent(input$stop_button, {
    app_state$running <- FALSE
  })
  # stop the app if the running state is FALSE
  observe({
    if (!app_state$running) {
      stopApp()
    }
  })
}
shinyApp(ui = ui, server = server)
# This is the data table that includes SHAP values in wide format. Each row contains sample_num as sample ID and the SHAP values for each feature.
shap_Mean_wide

## sample_num Bare.nuclei Bl.cromatin Cell.shape Cell.size Cl.thickness
## 1: 1 -0.04253333 -0.02653333 -0.04220000 -0.04840000 -0.001200000
## 2: 2 -0.02833333 -0.02566667 -0.03960000 -0.05206667 -0.023866667
## 3: 3 0.10960000 0.06413333 0.15053333 0.16226667 0.091533333
## 4: 4 -0.04093333 -0.02233333 -0.03926667 -0.04906667 -0.019800000
## 5: 5 0.11146667 0.06286667 0.14373333 0.16266667 0.092466667
## ---
## 177: 177 0.05420000 0.00460000 0.16913333 0.17960000 0.129533333
## 178: 178 -0.03653333 -0.02366667 -0.04186667 -0.04813333 -0.020200000
## 179: 179 -0.01233333 0.07166667 0.17593333 0.18326667 0.107600000
## 180: 180 0.10073333 0.06220000 0.14433333 0.14613333 0.098800000
## 181: 181 -0.04646667 -0.03060000 -0.04333333 -0.05633333 -0.002533333
## Epith.c.size Marg.adhesion Mitoses Normal.nucleoli age
## 1: -0.03013333 -0.017666667 -0.0070666667 -0.04053333 -6.000000e-04
## 2: -0.02900000 -0.016333333 -0.0068666667 -0.03880000 -3.333333e-04
## 3: 0.01546667 0.073800000 -0.0002000000 0.06560000 7.000000e-03
## 4: -0.02473333 -0.017200000 -0.0066000000 -0.03746667 -6.000000e-04
## 5: 0.01986667 0.044200000 0.0002666667 0.09533333 8.533333e-03
## ---
## 177: 0.04873333 0.024200000 -0.0002000000 0.08946667 6.733333e-03
## 178: -0.02646667 -0.014600000 -0.0074666667 -0.04053333 -6.666667e-05
## 179: 0.05613333 0.048133333 0.0006666667 0.09773333 4.400000e-03
## 180: 0.03260000 0.044466667 0.0024666667 0.10200000 5.800000e-03
## 181: -0.02926667 -0.001266667 -0.0078000000 -0.03520000 -3.133333e-03
## sex
## 1: 0.0004666667
## 2: 0.0004666667
## 3: -0.0021333333
## 4: -0.0024000000
## 5: -0.0018000000
## ---
## 177: -0.0024000000
## 178: -0.0008666667
## 179: 0.0024000000
## 180: -0.0019333333
## 181: 0.0015333333
# This is the data table of SHAP values in long format. It includes the feature name, mean_phi (the mean SHAP value for the feature across samples), scaled feature values, the SHAP value for each feature, the sample ID, and whether the prediction was correct (i.e., predicted class = actual class).
shap_Mean_long

## feature mean_phi Phi f_val sample_num correct_prediction
## 1: Bare.nuclei 0.056288636 -0.0425333333 0 1 Correct
## 2: Bare.nuclei 0.056288636 -0.0283333333 0 2 Correct
## 3: Bare.nuclei 0.056288636 0.1096000000 1 3 Correct
## 4: Bare.nuclei 0.056288636 -0.0409333333 0 4 Correct
## 5: Bare.nuclei 0.056288636 0.1114666667 1 5 Correct
## ---
## 1987: sex 0.001761364 -0.0024000000 0 177 Correct
## 1988: sex 0.001761364 -0.0008666667 0 178 Correct
## 1989: sex 0.001761364 0.0024000000 1 179 Correct
## 1990: sex 0.001761364 -0.0019333333 0 180 Correct
## 1991: sex 0.001761364 0.0015333333 1 181 Correct
## pred_prob
## 1: 0.996
## 2: 1.000
## 3: 0.002
## 4: 1.000
## 5: 0.000
## ---
## 1987: 0.036
## 1988: 1.000
## 1989: 0.004
## 1990: 0.002
## 1991: 0.994
Patient subgroups determined by SHAP clusters
SHAP clustering is a method to better understand why a model may perform better for some patients than for others. For example, you can identify patient subgroups with SHAP patterns that differ from the rest of the cohort, which can explain why your model performs better or worse for those patients than on average across the whole dataset. The differences are visible in the per-cluster SHAP plots, which, taken together, would reproduce the overall SHAP summary plot. The edges again reflect how features may interact with each other within each individual sample (instance).
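The clustering idea can be pictured with base-R kmeans on a SHAP-like matrix. The matrix below is a toy stand-in for shap_Mean_wide, with two planted subgroups of opposite SHAP patterns; SHAPclust combines this kind of step with the per-cluster plots and confusion matrices:

```r
# toy SHAP-like matrix: 10 samples with positive and 10 with negative patterns
set.seed(246)
toy_shap <- rbind(matrix(rnorm(50,  1, .2), ncol = 5),
                  matrix(rnorm(50, -1, .2), ncol = 5))
km <- kmeans(toy_shap, centers = 2, nstart = 10)
table(km$cluster)  # two subgroups recovered from the SHAP patterns
```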
# the number of clusters can be changed
SHAP_plot_clusters <- SHAPclust(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  shap_Mean_wide = shap_Mean_wide,
  shap_Mean_long = shap_Mean_long,
  num_of_clusters = 4,
  seed = seed,
  subset = .8
)
# note that subset must have the same value as in the earlier SHAP analysis
# display the SHAP cluster plots
SHAP_plot_clusters[[1]]
# display the confusion matrices corresponding to the SHAP clusters (patient subsets determined by SHAP clusters)
SHAP_plot_clusters[[2]]
Model fairness (sensitivity analysis)
Sometimes we would like to investigate whether our model performs equally well across subgroups defined by categorical variables such as sex.
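The fairness check can be pictured in miniature by comparing an AUC-style statistic per subgroup. The helper below uses the Wilcoxon rank-sum identity for the AUC; it is a self-contained illustration with hypothetical vectors, not what eFairness computes internally (which reports the full set of mlr3 measures):

```r
# AUC via the rank-sum (Wilcoxon) identity; truth: 1 = positive class
auc_rank <- function(prob, truth) {
  r <- rank(prob)
  n_pos <- sum(truth == 1)
  n_neg <- sum(truth == 0)
  (sum(r[truth == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}
prob  <- c(.9, .8, .7, .4, .3, .2)
truth <- c(1, 1, 0, 1, 0, 0)
auc_rank(prob, truth)  # compute this per subgroup and compare
```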
# decide which variables to test
# here we chose sex from the variables in the dataset
Fairness_results <- eFairness(
  task = maintask,
  trained_model = mylrn,
  splits = splits,
  target_variable = "sex",
  var_levels = c("Male", "Female")
)
# ROC curves for the subgroups for the development (left) and test (right) sets
Fairness_results[[1]]
# performance in the subgroups for the development set
Fairness_results[[2]]

## Male Female
## auc 1 1
## bacc 1 1
## mcc 1 1
## bbrier 0 0
## ppv 1 1
## npv 1 1
## specificity 1 1
## sensitivity 1 1
## prauc 1 1
# performance in the subgroups for the test set
Fairness_results[[3]]

## Male Female
## auc 0.99 0.99
## bacc 0.99 0.97
## mcc 0.96 0.93
## bbrier 0.03 0.03
## ppv 0.94 0.94
## npv 1.00 0.99
## specificity 0.97 0.96
## sensitivity 1.00 0.98
## prauc 0.98 0.99
Model parameters
# get model parameters
model_params <- mylrn$param_set
print(as.data.table(model_params))

## id class lower upper levels nlevels
## 1: ntree ParamInt 1 Inf Inf
## 2: mtry ParamInt 1 Inf Inf
## 3: replace ParamLgl NA NA TRUE,FALSE 2
## 4: classwt ParamUty NA NA Inf
## 5: cutoff ParamUty NA NA Inf
## 6: strata ParamUty NA NA Inf
## 7: sampsize ParamUty NA NA Inf
## 8: nodesize ParamInt 1 Inf Inf
## 9: maxnodes ParamInt 1 Inf Inf
## 10: importance ParamFct NA NA accuracy,gini,none,FALSE 4
## 11: localImp ParamLgl NA NA TRUE,FALSE 2
## 12: proximity ParamLgl NA NA TRUE,FALSE 2
## 13: oob.prox ParamLgl NA NA TRUE,FALSE 2
## 14: norm.votes ParamLgl NA NA TRUE,FALSE 2
## 15: do.trace ParamLgl NA NA TRUE,FALSE 2
## 16: keep.forest ParamLgl NA NA TRUE,FALSE 2
## 17: keep.inbag ParamLgl NA NA TRUE,FALSE 2
## 18: predict.all ParamLgl NA NA TRUE,FALSE 2
## 19: nodes ParamLgl NA NA TRUE,FALSE 2
## is_bounded special_vals default storage_type tags
## 1: FALSE <list[0]> 500 integer train,predict
## 2: FALSE <list[0]> <NoDefault[3]> integer train
## 3: TRUE <list[0]> TRUE logical train
## 4: FALSE <list[0]> list train
## 5: FALSE <list[0]> <NoDefault[3]> list train
## 6: FALSE <list[0]> <NoDefault[3]> list train
## 7: FALSE <list[0]> <NoDefault[3]> list train
## 8: FALSE <list[0]> 1 integer train
## 9: FALSE <list[0]> <NoDefault[3]> integer train
## 10: TRUE <list[1]> FALSE character train
## 11: TRUE <list[0]> FALSE logical train
## 12: TRUE <list[0]> FALSE logical train,predict
## 13: TRUE <list[0]> <NoDefault[3]> logical train
## 14: TRUE <list[0]> TRUE logical train
## 15: TRUE <list[0]> FALSE logical train
## 16: TRUE <list[0]> TRUE logical train
## 17: TRUE <list[0]> FALSE logical train
## 18: TRUE <list[0]> FALSE logical predict
## 19: TRUE <list[0]> FALSE logical predict
Report packages that have been used
The explainer package has been developed with the following dependencies, which ensure its functionality and compatibility. For a detailed reference on each package, please consult its documentation and citation information on CRAN (the Comprehensive R Archive Network) or the official package repository.
library("report")
oursessionreport <- report(sessionInfo())
summary(oursessionreport)

## The analysis was done using the R Statistical language (v4.1.2; R Core Team,
## 2021) on Windows 10 x64, using the packages ggpubr (v0.4.0), gridExtra (v2.3),
## egg (v0.4.5), plotly (v4.10.0), cowplot (v1.1.1), broom (v1.0.0), ggplot2
## (v3.3.6), reshape2 (v1.4.4), stringr (v1.4.0), forcats (v0.5.1), tidyr
## (v1.2.0), readr (v2.1.2), dplyr (v1.0.9), writexl (v1.4.0), tibble (v3.1.8),
## mlr3 (v0.14.0), purrr (v0.3.4), cvms (v1.3.4), report (v0.5.1), data.table
## (v1.14.2), plotROC (v2.3.0), mlr3viz (v0.5.9), mlr3learners (v0.5.3), iml
## (v0.11.0), ggpmisc (v0.4.7), ggpp (v0.4.4), explainer (v1.0.0),
## mlr3extralearners (v0.5.18), psych (v2.2.5), magrittr (v2.0.3), tidyverse
## (v1.3.2), shiny (v1.7.4) and knitr (v1.39).
oursessionreport_df <- as.data.frame(oursessionreport)
# fwrite writes delimited text regardless of the file extension, so use
# writexl (already loaded as a dependency) to produce a real .xlsx file
writexl::write_xlsx(oursessionreport_df, path = paste0("sessioninfo", seed, ".xlsx"))
summary(as.data.frame(oursessionreport))

## Package | Version
## ---------------------------
## broom | 1.0.0
## cowplot | 1.1.1
## cvms | 1.3.4
## data.table | 1.14.2
## dplyr | 1.0.9
## egg | 0.4.5
## explainer | 1.0.0
## forcats | 0.5.1
## ggplot2 | 3.3.6
## ggpmisc | 0.4.7
## ggpp | 0.4.4
## ggpubr | 0.4.0
## gridExtra | 2.3
## iml | 0.11.0
## knitr | 1.39
## magrittr | 2.0.3
## mlr3 | 0.14.0
## mlr3extralearners | 0.5.18
## mlr3learners | 0.5.3
## mlr3viz | 0.5.9
## plotly | 4.10.0
## plotROC | 2.3.0
## psych | 2.2.5
## purrr | 0.3.4
## R | 4.1.2
## readr | 2.1.2
## report | 0.5.1
## reshape2 | 1.4.4
## shiny | 1.7.4
## stringr | 1.4.0
## tibble | 3.1.8
## tidyr | 1.2.0
## tidyverse | 1.3.2
## writexl | 1.4.0